import torch
import torch.utils.data 
from torch.nn import functional as F
import pytorch_lightning as pl
from models import *
from models import scorenet
from models.sde import init_sde


class CombinedModel(pl.LightningModule):
    """Lightning module wrapping the SDF/VAE, latent-diffusion and pose-score sub-models.

    ``specs['training_task']`` selects which sub-models are built and which
    training routine ``training_step`` dispatches to:

    ============  =====================================================
    task          sub-models built
    ============  =====================================================
    'modulation'  sdf_model, vae_model
    'diffusion'   diffusion_model, scorenet
    'score'       sdf_model, scorenet (only scorenet is optimised)
    'combined'    sdf_model, vae_model, diffusion_model, scorenet
    'test'        sdf_model, vae_model, diffusion_model, scorenet
                  (no training routine / optimiser is defined)
    ============  =====================================================
    """

    def __init__(self, specs):
        """Build the sub-models required by ``specs['training_task']``.

        Args:
            specs: experiment configuration dict. Keys read here (as needed by
                the task): 'training_task', 'sde_mode', 'SdfModelSpecs',
                'latent_std' (optional, default 0.25), 'num_parts',
                'diffusion_model_specs', 'diffusion_specs', 'pose_mode',
                'regression_head'.
        """
        super().__init__()
        self.specs = specs
        self.task = specs['training_task']

        # SDE helpers. marginal_prob_fn is handed to the score network below,
        # so these are computed exactly once and shared.
        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = init_sde(
            self.specs['sde_mode'])

        # Use real tuples so ``in`` is a membership test, not a substring test
        # (``x in ('score')`` is ``x in 'score'``).
        # Construction order (sdf -> vae -> diffusion -> score) is the same as
        # before for every task, so any RNG-dependent initialisation happens in
        # the same sequence.
        if self.task in ('combined', 'modulation', 'score', 'test'):
            self.sdf_model = SdfModel(specs=specs)
        if self.task in ('combined', 'modulation', 'test'):
            self.vae_model = self._build_vae(specs)
        if self.task in ('combined', 'diffusion', 'test'):
            self.diffusion_model = DiffusionModel(model=DiffusionNet(**specs["diffusion_model_specs"]),
                                                  **specs["diffusion_specs"])
        if self.task in ('combined', 'diffusion', 'score', 'test'):
            self.scorenet = scorenet.PoseScoreNet(self.specs, self.marginal_prob_fn,
                                                  pose_mode=self.specs['pose_mode'],
                                                  regression_head=self.specs['regression_head'])

    @staticmethod
    def _build_vae(specs):
        """Build the conditional BetaVAE over concatenated plane features.

        Input/latent width is ``latent_dim * 3`` (presumably three feature
        planes concatenated along channels — confirm against SdfModel), with
        five equal-width hidden layers and a condition of size num_parts - 1.
        """
        feature_dim = specs["SdfModelSpecs"]["latent_dim"]
        modulation_dim = feature_dim * 3
        latent_std = specs.get("latent_std", 0.25)
        hidden_dims = [modulation_dim] * 5
        return BetaVAE(in_channels=feature_dim * 3, latent_dim=modulation_dim,
                       condition_dim=specs['num_parts'] - 1, hidden_dims=hidden_dims,
                       kl_std=latent_std)

    def training_step(self, x, idx):
        """Run one training step for the configured task.

        Args:
            x: batch dict from the dataloader; required keys depend on the task.
            idx: batch index supplied by Lightning (unused).

        Returns:
            The scalar loss, or ``None`` when the VAE loss could not be computed
            (with automatic optimisation Lightning skips the step for ``None``).

        Raises:
            ValueError: if the task has no training routine (e.g. 'test').
        """
        dispatch = {
            'combined': self.train_combined,
            'modulation': self.train_modulation,
            'diffusion': self.train_diffusion,
            'score': self.train_score,
        }
        train_fn = dispatch.get(self.task)
        if train_fn is None:
            raise ValueError(f"training_task {self.task!r} has no training routine")
        return train_fn(x)

    def configure_optimizers(self):
        """Create one Adam optimiser with a parameter group per trained sub-model.

        Learning rates: specs['sdf_lr'] for SDF model + VAE, specs['diff_lr'] for
        the diffusion model, specs['score_lr'] for the score network.

        Raises:
            ValueError: if the task has no training routine (e.g. 'test').
        """
        if self.task == 'combined':
            params_list = [
                {'params': list(self.sdf_model.parameters()) + list(self.vae_model.parameters()),
                 'lr': self.specs['sdf_lr']},
                {'params': self.diffusion_model.parameters(), 'lr': self.specs['diff_lr']},
                {'params': self.scorenet.parameters(), 'lr': self.specs['score_lr']},
            ]
        elif self.task == 'modulation':
            params_list = [
                {'params': list(self.sdf_model.parameters()) + list(self.vae_model.parameters()),
                 'lr': self.specs['sdf_lr']},
            ]
        elif self.task == 'diffusion':
            params_list = [
                {'params': self.diffusion_model.parameters(), 'lr': self.specs['diff_lr']},
            ]
        elif self.task == 'score':
            params_list = [
                {'params': self.scorenet.parameters(), 'lr': self.specs['score_lr']},
            ]
        else:
            raise ValueError(f"training_task {self.task!r} has no optimiser configuration")

        optimizer = torch.optim.Adam(params_list)
        return {
            "optimizer": optimizer,
        }

    @staticmethod
    def _sdf_atc_losses(pred_sdf, gt_sdf, pred_atc, atc):
        """Return ``(sdf_loss, atc_loss)`` as mean L1 losses.

        ``atc`` is unsqueezed at dim 1 and repeated to ``pred_atc.size(1)`` so
        it matches ``pred_atc``. The mean over all elements equals the former
        per-sample-mean-then-batch-mean, since every sample has the same size.
        """
        sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt_sdf.squeeze())
        atc_target = atc.unsqueeze(dim=1).repeat(1, pred_atc.size(1), 1)
        atc_loss = F.l1_loss(pred_atc, atc_target)
        return sdf_loss, atc_loss

    def _score_losses(self, data):
        """Return ``(pose_loss, joint_loss, seg_loss)`` for the score network.

        Encodes ``data['pts']`` and stores the features in ``data['pts_feat']``
        (mutates ``data``); ``data`` is then passed to collect_score_loss /
        collect_joint_loss, which presumably read those features. ``data['seg']``
        provides per-point part labels for the segmentation loss.
        """
        camera_partial_pc = data['pts']
        pts_feat = self.scorenet.pts_encoder(camera_partial_pc)
        pred_seg = self.scorenet.seg_encoder(camera_partial_pc)
        data['pts_feat'] = pts_feat

        pose_loss = self.scorenet.collect_score_loss(self.specs, data, teacher_model=None, pts_feat_teacher=None)
        joint_loss = self.scorenet.collect_joint_loss(self.specs, data, teacher_model=None, pts_feat_teacher=None)

        gt_seg_onehot = F.one_hot(data['seg'].long(), num_classes=self.specs['num_parts'])
        seg_loss = self.scorenet.compute_miou_loss(pred_seg, gt_seg_onehot)
        return pose_loss, joint_loss, seg_loss

    def train_modulation(self, x):
        """Train the SDF model and VAE on one batch.

        Reads 'xyz', 'gt_sdf', 'point_cloud', 'atc' and 'seg' from ``x``.
        Returns the weighted total loss, or ``None`` if the VAE loss fails.
        """
        xyz = x['xyz']
        gt = x['gt_sdf']
        pc = x['point_cloud']
        atc = x['atc']
        gt_seg = x['seg']

        # Encode the point cloud into plane features; concatenate along dim 1
        # to form the VAE input.
        plane_features = self.sdf_model.pointnet.get_plane_features(pc)
        original_features = torch.cat(plane_features, dim=1)

        # The VAE is conditioned on ``atc`` (presumably per-joint articulation state — confirm).
        out = self.vae_model(original_features, atc)
        reconstructed_plane_feature = out[0]

        # Predict SDF, articulation and part segmentation from the reconstruction.
        pred_sdf, pred_atc = self.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)
        pred_seg = self.sdf_model.forward_with_seg_features(reconstructed_plane_feature, pc)

        try:
            vae_loss = self.vae_model.loss_function(*out, M_N=self.specs["kld_weight"])
        except Exception:
            # Best effort: skip the batch instead of aborting the run. The message
            # suggests loss_function raises on a NaN loss (confirm). Catching
            # Exception rather than a bare except lets KeyboardInterrupt through.
            print("vae loss is nan at epoch {}...".format(self.current_epoch))
            return None

        sdf_loss, atc_loss = self._sdf_atc_losses(pred_sdf, gt, pred_atc, atc)

        gt_seg_onehot = F.one_hot(gt_seg.long(), num_classes=self.specs['num_parts'])
        seg_loss = self.sdf_model.compute_miou_loss(pred_seg, gt_seg_onehot)

        loss = sdf_loss + vae_loss + 0.1 * atc_loss + 0.1 * seg_loss

        loss_dict = {"sdf": sdf_loss, "atc": atc_loss, "vae": vae_loss, "seg": seg_loss}
        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)

        return loss

    def train_diffusion(self, x):
        """Train the diffusion model on precomputed latents ``x['latent']``.

        ``x['point_cloud']`` is used as the condition only when
        specs['diffusion_model_specs']['cond'] is truthy.
        """
        self.train()
        pc = x['point_cloud']
        latent = x['latent']
        cond = pc if self.specs['diffusion_model_specs']['cond'] else None

        diff_loss, diff_100_loss, diff_1000_loss, _, _ = self.diffusion_model.diffusion_model_from_latent(
            latent, cond=cond)

        loss_dict = {
            "total": diff_loss,
            "diff100": diff_100_loss,
            "diff1000": diff_1000_loss,
        }
        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)

        return diff_loss

    def train_score(self, x):
        """Train the pose score network (score/energy, without feedback) on one batch.

        Mutates ``x`` by adding 'pts_feat'. Returns pose + joint + 0.1 * seg loss.
        """
        pose_loss, joint_loss, seg_loss = self._score_losses(x)
        score_loss = pose_loss + joint_loss + 0.1 * seg_loss

        # NOTE(review): "score_loss" logs the pose loss alone here, whereas
        # train_combined logs the full score loss under the same key.
        loss_dict = {
            "seg_loss": seg_loss,
            "score_loss": pose_loss,
            "joint_loss": joint_loss,
        }
        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)

        return score_loss

    def train_combined(self, x):
        """Train SDF model, VAE, diffusion model and score network end to end.

        Reads 'xyz', 'gt_sdf', 'point_cloud', 'atc', 'canonical_partial_pc' plus
        the score-network inputs ('pts', 'seg', ...) from ``x``. Returns the
        summed loss, or ``None`` if the VAE loss fails.
        """
        xyz = x['xyz']
        gt = x['gt_sdf']
        pc = x['point_cloud']
        atc = x['atc']
        partial_pc = x['canonical_partial_pc']

        # VAE reconstruction branch.
        plane_features = self.sdf_model.pointnet.get_plane_features(pc)
        original_features = torch.cat(plane_features, dim=1)
        out = self.vae_model(original_features, atc)
        reconstructed_plane_feature, latent = out[0], out[-1]
        pred_sdf, pred_atc = self.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)

        try:
            vae_loss = self.vae_model.loss_function(*out, M_N=self.specs["kld_weight"])
        except Exception:
            # Best effort: skip the batch; see train_modulation for rationale.
            print("vae loss is nan at epoch {}...".format(self.current_epoch))
            return None
        sdf_loss, atc_loss = self._sdf_atc_losses(pred_sdf, gt, pred_atc, atc)

        # Diffusion branch: only the first 768 latent channels are diffused and
        # the rest are passed through unchanged.
        # NOTE(review): 768 is hard-coded; presumably the shape part of the VAE
        # latent — confirm against the VAE's latent layout.
        cond = partial_pc if self.specs['diffusion_model_specs']['cond'] else None
        latent_shape = latent[:, :768]
        diff_loss, _, _, pred_latent, _ = self.diffusion_model.diffusion_model_from_latent(
            latent_shape, cond=cond)
        pred_latent = torch.cat([pred_latent, latent[:, 768:]], dim=1)

        # Score branch: the batch dict itself carries the score-network inputs.
        pose_loss, joint_loss, seg_loss = self._score_losses(x)
        score_loss = pose_loss + joint_loss + 0.1 * seg_loss

        # Decode the diffusion prediction and supervise it with the GT SDF.
        generated_plane_feature = self.vae_model.decode(pred_latent)
        generated_sdf_pred, _ = self.sdf_model.forward_with_plane_features(generated_plane_feature, xyz)
        generated_sdf_loss = F.l1_loss(generated_sdf_pred.squeeze(), gt.squeeze())

        loss = sdf_loss + vae_loss + diff_loss + generated_sdf_loss + 0.1 * atc_loss + score_loss

        loss_dict = {
            "total": loss,
            "sdf": sdf_loss,
            "vae": vae_loss,
            "diff": diff_loss,
            "atc": atc_loss,
            "score_loss": score_loss,
            "gensdf": generated_sdf_loss,
        }
        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)

        return loss
